{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 3: GAT implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Implementation of GAT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Official resources:\n",
    "* [Code](https://dsgiitr.com/blogs/gat/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.11.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/antonio/anaconda3/envs/geometric_new/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Environment setup: install the PyTorch Geometric stack for the local torch build.\n",
    "import os\n",
    "import torch\n",
    "# Expose the torch version as the TORCH environment variable so the shell\n",
    "# commands below can substitute it into the wheel-index URL via ${TORCH}.\n",
    "os.environ['TORCH'] = torch.__version__\n",
    "print(torch.__version__)\n",
    "\n",
    "# NOTE(review): none of the installs is version-pinned and the last one is taken\n",
    "# straight from the GitHub repository, so a later re-run may pull different code.\n",
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GATLayer(nn.Module):\n",
    "    \"\"\"Skeleton of the Graph Attention layer, to be filled in step by step.\n",
    "\n",
    "    The constructor only initialises ``nn.Module``; ``forward`` is a placeholder\n",
    "    that prints an empty line and returns ``None``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input, adj):\n",
    "        # Placeholder body: no computation yet.\n",
    "        print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's start from the forward method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear Transformation\n",
    "\n",
    "$$\n",
    "\\bar{h'}_i = \\textbf{W}\\cdot \\bar{h}_i\n",
    "$$\n",
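    "with $\\textbf{W}\\in\\mathbb R^{F'\\times F}$ and $\\bar{h}_i\\in\\mathbb R^{F}$.\n",
    "\n",
    "In the code below the node features are stacked as the rows of a matrix $H\\in\\mathbb R^{N\\times F}$, so the same map is computed for all nodes at once as $H\\,\\textbf{W}^{\\top}$; the code therefore stores the weight matrix with shape $(F, F')$, i.e. as the transpose of the $\\textbf{W}$ above.\n",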
    "\n",
    "$$\n",
    "\\bar{h'}_i \\in \\mathbb{R}^{F'}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 2])\n"
     ]
    }
   ],
   "source": [
    "in_features = 5\n",
    "out_features = 2\n",
    "nb_nodes = 3\n",
    "\n",
    "# Seed the global RNG so the random weights and features below (and every\n",
    "# tensor printed from them) are identical on each Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Weight matrix stored as (in_features, out_features): allocated as zeros and\n",
    "# then overwritten in place with Xavier-uniform values.\n",
    "W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n",
    "nn.init.xavier_uniform_(W.data, gain=1.414)\n",
    "\n",
    "# Random node features, one row per node. Deliberately not called `input`, so\n",
    "# the notebook does not shadow Python's built-in input().\n",
    "node_features = torch.rand(nb_nodes, in_features)\n",
    "\n",
    "# Linear transformation of all nodes at once: (N, F) @ (F, F') -> (N, F')\n",
    "h = torch.mm(node_features, W)\n",
    "N = h.size(0)\n",
    "\n",
    "print(h.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention Mechanism"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](AttentionMechanism.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 1])\n"
     ]
    }
   ],
   "source": [
    "# Attention vector `a`, shape (2 * out_features, 1): allocated as zeros and\n",
    "# then filled in place by Xavier-uniform initialisation.\n",
    "a = nn.Parameter(torch.zeros(2 * out_features, 1))\n",
    "nn.init.xavier_uniform_(a.data, gain=1.414)\n",
    "print(a.shape)\n",
    "\n",
    "# LeakyReLU with negative slope 0.2.\n",
    "leakyrelu = nn.LeakyReLU(negative_slope=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every ordered pair (i, j) concatenate [h_i || h_j]; after the final view\n",
    "# the tensor has shape (N, N, 2 * out_features) and entry [i, j] holds that pair.\n",
    "a_input = torch.cat(\n",
    "    [\n",
    "        h.repeat(1, N).view(N * N, -1),  # h_i: each row repeated N times in a run\n",
    "        h.repeat(N, 1),                  # h_j: the whole matrix tiled N times\n",
    "    ],\n",
    "    dim=1,\n",
    ").view(N, -1, 2 * out_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](a_input.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unnormalised attention scores: e[i, j] = LeakyReLU(a^T [h_i || h_j]).\n",
    "e = leakyrelu((a_input @ a).squeeze(dim=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 3, 4]) torch.Size([4, 1])\n",
      "\n",
      "torch.Size([3, 3, 1])\n",
      "\n",
      "torch.Size([3, 3])\n"
     ]
    }
   ],
   "source": [
    "# Shape walk-through of the score computation, one step per printed line.\n",
    "raw_scores = torch.matmul(a_input, a)  # (N, N, 2F') @ (2F', 1)\n",
    "print(a_input.shape, a.shape)\n",
    "print(\"\")\n",
    "print(raw_scores.shape)\n",
    "print(\"\")\n",
    "print(raw_scores.squeeze(2).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Masked Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 3])\n"
     ]
    }
   ],
   "source": [
    "# Masked Attention\n",
    "adj = torch.randint(2, (3, 3))\n",
    "\n",
    "zero_vec  = -9e15*torch.ones_like(e)\n",
    "print(zero_vec.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 1, 0],\n",
      "        [1, 0, 1],\n",
      "        [0, 0, 0]]) \n",
      " tensor([[0.6686, 1.1267, 1.0171],\n",
      "        [0.8126, 1.2706, 1.1611],\n",
      "        [0.7363, 1.1943, 1.0848]], grad_fn=<LeakyReluBackward0>) \n",
      " tensor([[-9.0000e+15, -9.0000e+15, -9.0000e+15],\n",
      "        [-9.0000e+15, -9.0000e+15, -9.0000e+15],\n",
      "        [-9.0000e+15, -9.0000e+15, -9.0000e+15]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-9.0000e+15,  1.1267e+00, -9.0000e+15],\n",
       "        [ 8.1261e-01, -9.0000e+15,  1.1611e+00],\n",
       "        [-9.0000e+15, -9.0000e+15, -9.0000e+15]], grad_fn=<SWhereBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the three ingredients first: adjacency, raw scores and the filler.\n",
    "print(adj, \"\\n\", e, \"\\n\", zero_vec)\n",
    "# Keep the score where an edge exists, otherwise use the large negative filler.\n",
    "attention = torch.where(adj > 0, e, zero_vec)\n",
    "attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-wise softmax over the masked scores. The mask is re-applied here instead\n",
    "# of reading `attention` back, so the cell is idempotent: re-running it no\n",
    "# longer computes a softmax of an already-normalised matrix.\n",
    "attention = F.softmax(torch.where(adj > 0, e, zero_vec), dim=1)\n",
    "# New node features: attention-weighted sum of the transformed features h.\n",
    "h_prime   = torch.matmul(attention, h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 1.0000, 0.0000],\n",
       "        [0.4138, 0.0000, 0.5862],\n",
       "        [0.3333, 0.3333, 0.3333]], grad_fn=<SoftmaxBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attention weights after the row-wise softmax (dim=1).\n",
    "attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0072,  1.0209],\n",
       "        [-0.1147,  0.8389],\n",
       "        [-0.0802,  0.8856]], grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Aggregated node features: the attention matrix multiplied by h.\n",
    "h_prime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### h_prime vs h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0072,  1.0209],\n",
      "        [-0.1147,  0.8389],\n",
      "        [-0.0802,  0.8856]], grad_fn=<MmBackward0>) \n",
      " tensor([[-0.1773,  0.6968],\n",
      "        [ 0.0072,  1.0209],\n",
      "        [-0.0706,  0.9392]], grad_fn=<MmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Aggregated features (h_prime) printed above the features before aggregation (h).\n",
    "print(h_prime,\"\\n\",h)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build the layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GATLayer(nn.Module):\n",
    "    \"\"\"Exercise template: ``forward`` is complete, ``__init__`` is left as a TODO.\n",
    "\n",
    "    ``forward`` reads ``self.W``, ``self.a``, ``self.out_features``,\n",
    "    ``self.leakyrelu``, ``self.dropout`` and ``self.concat``; none of them is\n",
    "    created by this constructor yet.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, dropout, alpha, concat=True):\n",
    "        super().__init__()\n",
    "        # TODO: store the hyper-parameters and create the attributes used by forward.\n",
    "\n",
    "    def forward(self, input, adj):\n",
    "        # Linear transformation of every node.\n",
    "        h = input.mm(self.W)\n",
    "        N = h.size(0)\n",
    "\n",
    "        # Attention mechanism: entry [i, j] of a_input is [h_i || h_j].\n",
    "        pairs = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)\n",
    "        a_input = pairs.view(N, -1, 2 * self.out_features)\n",
    "        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))\n",
    "\n",
    "        # Masked attention: non-edges get a large negative score before the softmax.\n",
    "        zero_vec = -9e15 * torch.ones_like(e)\n",
    "        attention = torch.where(adj > 0, e, zero_vec)\n",
    "\n",
    "        attention = F.softmax(attention, dim=1)\n",
    "        attention = F.dropout(attention, self.dropout, training=self.training)\n",
    "        h_prime = torch.matmul(attention, h)\n",
    "\n",
    "        return F.elu(h_prime) if self.concat else h_prime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GATLayer(nn.Module):\n",
    "    \"\"\"Single-head graph attention layer working on a dense adjacency matrix.\n",
    "\n",
    "    Args:\n",
    "        in_features: length F of each input node feature vector.\n",
    "        out_features: length F' of each output node feature vector.\n",
    "        dropout: dropout probability applied to the attention coefficients\n",
    "            (e.g. 0.6).\n",
    "        alpha: negative slope of the LeakyReLU applied to the raw attention\n",
    "            scores (e.g. 0.2).\n",
    "        concat: when True the output goes through ELU; when False the\n",
    "            aggregated features are returned as they are. Use True for all\n",
    "            layers except the output layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, dropout, alpha, concat=True):\n",
    "        super(GATLayer, self).__init__()\n",
    "        self.dropout       = dropout\n",
    "        self.in_features   = in_features\n",
    "        self.out_features  = out_features\n",
    "        self.alpha         = alpha\n",
    "        self.concat        = concat\n",
    "\n",
    "        # Learnable parameters: allocated as zeros, then filled in place with\n",
    "        # Xavier-uniform values.\n",
    "        # W: linear map applied to every node, shape (in_features, out_features).\n",
    "        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n",
    "        nn.init.xavier_uniform_(self.W.data, gain=1.414)\n",
    "\n",
    "        # a: attention vector that scores a concatenated pair [h_i || h_j].\n",
    "        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))\n",
    "        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n",
    "\n",
    "        self.leakyrelu = nn.LeakyReLU(self.alpha)\n",
    "\n",
    "    def forward(self, input, adj):\n",
    "        \"\"\"Compute attention-weighted node features.\n",
    "\n",
    "        Args:\n",
    "            input: node feature matrix of shape (N, in_features).\n",
    "            adj: (N, N) adjacency tensor; an entry > 0 at [i, j] lets node i\n",
    "                attend to node j.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (N, out_features).\n",
    "        \"\"\"\n",
    "        # Linear transformation of every node: (N, F) @ (F, F') -> (N, F').\n",
    "        h = torch.mm(input, self.W)\n",
    "        N = h.size(0)\n",
    "\n",
    "        # Attention mechanism: entry [i, j] of a_input is [h_i || h_j].\n",
    "        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)\n",
    "        e       = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))\n",
    "\n",
    "        # Masked attention: non-edges get a large negative score so that they\n",
    "        # vanish under the row-wise softmax.\n",
    "        zero_vec  = -9e15*torch.ones_like(e)\n",
    "        attention = torch.where(adj > 0, e, zero_vec)\n",
    "\n",
    "        attention = F.softmax(attention, dim=1)\n",
    "        attention = F.dropout(attention, self.dropout, training=self.training)\n",
    "        h_prime   = torch.matmul(attention, h)\n",
    "\n",
    "        if self.concat:\n",
    "            return F.elu(h_prime)\n",
    "        return h_prime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of Classes in Cora: 7\n",
      "Number of Node Features in Cora: 1433\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `Data` and `plt` are imported here but not used in any cell of\n",
    "# this notebook; these imports would also normally live in the top import cell.\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.nn import GATConv\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the dataset named below, using /tmp/<name> as its root directory.\n",
    "name_data = 'Cora'\n",
    "dataset = Planetoid(root= '/tmp/' + name_data, name = name_data)\n",
    "# Attach the NormalizeFeatures transform to the already-constructed dataset.\n",
    "dataset.transform = T.NormalizeFeatures()\n",
    "\n",
    "print(f\"Number of Classes in {name_data}:\", dataset.num_classes)\n",
    "print(f\"Number of Node Features in {name_data}:\", dataset.num_node_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.9471, grad_fn=<NllLossBackward0>)\n",
      "tensor(0.7560, grad_fn=<NllLossBackward0>)\n",
      "tensor(0.5531, grad_fn=<NllLossBackward0>)\n",
      "tensor(0.5089, grad_fn=<NllLossBackward0>)\n",
      "tensor(0.5936, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "class GAT(torch.nn.Module):\n",
    "    \"\"\"Two-layer classifier built from ``GATConv`` layers.\n",
    "\n",
    "    The input width and the number of classes are read from the module-level\n",
    "    ``dataset``; ``forward`` returns per-node log-probabilities.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hid = 8        # output channels passed to the first layer\n",
    "        self.in_head = 8    # `heads` of the first layer\n",
    "        self.out_head = 1   # `heads` of the second layer\n",
    "\n",
    "        self.conv1 = GATConv(\n",
    "            dataset.num_features,\n",
    "            self.hid,\n",
    "            heads=self.in_head,\n",
    "            dropout=0.6,\n",
    "        )\n",
    "        self.conv2 = GATConv(\n",
    "            self.hid * self.in_head,\n",
    "            dataset.num_classes,\n",
    "            concat=False,\n",
    "            heads=self.out_head,\n",
    "            dropout=0.6,\n",
    "        )\n",
    "\n",
    "    def forward(self, data):\n",
    "        edge_index = data.edge_index\n",
    "\n",
    "        # dropout -> conv1 -> ELU -> dropout -> conv2 -> log-softmax\n",
    "        x = F.dropout(data.x, p=0.6, training=self.training)\n",
    "        x = F.elu(self.conv1(x, edge_index))\n",
    "        x = F.dropout(x, p=0.6, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)\n",
    "    \n",
    "    \n",
    "    \n",
    "# Training runs on the CPU, as before. The original cell detected CUDA and then\n",
    "# unconditionally overwrote the result with \"cpu\"; that override is now an\n",
    "# explicit switch. Set FORCE_CPU = False to use a GPU when one is available.\n",
    "FORCE_CPU = True\n",
    "use_cuda = torch.cuda.is_available() and not FORCE_CPU\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "\n",
    "model = GAT().to(device)\n",
    "data = dataset[0].to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n",
    "\n",
    "# Training mode is set once: nothing inside the loop switches the model to\n",
    "# eval mode, so repeating the call on every iteration was redundant.\n",
    "model.train()\n",
    "for epoch in range(1000):\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data)\n",
    "    # Loss is computed on the nodes selected by train_mask only.\n",
    "    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n",
    "\n",
    "    # Report the loss every 200 epochs.\n",
    "    if epoch % 200 == 0:\n",
    "        print(loss)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8210\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward\n",
    "# pass. argmax replaces `_, pred = ....max(dim=1)`, which also rebound `_` in\n",
    "# the notebook namespace.\n",
    "with torch.no_grad():\n",
    "    pred = model(data).argmax(dim=1)\n",
    "# Number of test nodes whose predicted class equals the label.\n",
    "correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()\n",
    "acc = correct / data.test_mask.sum().item()\n",
    "print('Accuracy: {:.4f}'.format(acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
